{
 "cells": [
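  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text-Image-to-Text SFT Training with the Align-Anything Framework\n",
    "\n",
    "This tutorial shows how to use the Align-Anything framework to run supervised fine-tuning (SFT) on a multimodal model."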
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "- Align-Anything is installed.\n",
    "- A text-image-to-text dataset. This tutorial uses the [PKU-Alignment/Align-Anything-TI2T-Instruction-100K](https://huggingface.co/datasets/PKU-Alignment/Align-Anything-TI2T-Instruction-100K) dataset.\n",
    "- The LLaVA-1.5-7b model, which can be downloaded [here](https://huggingface.co/llava-hf/llava-1.5-7b-hf).\n",
    "- A GPU with at least 70 GB of memory.\n",
    "\n",
    "> GPUs with less memory may also work. Scripts that use smaller models will be provided in the future."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Pre-trained Model\n",
    "\n",
    "First, we load a pre-trained model. We use LLaVA-1.5-7b, a multimodal model that takes text and images as input and generates text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from align_anything.models.pretrained_model import load_pretrained_models\n",
    "from align_anything.utils.multi_process import get_current_device\n",
    "\n",
    "# Load the pre-trained model, tokenizer, and processor\n",
    "model, tokenizer, processor = load_pretrained_models(\n",
    "    \"/path/to/llava-1.5-7b-hf\",  # Replace with your model path\n",
    "    model_max_length=4096,\n",
    "    padding_side='right',\n",
    "    trust_remote_code=True,\n",
    "    modality=['image'],\n",
    ")\n",
    "\n",
    "# Move the model to the available device (GPU if available)\n",
    "model = model.to(get_current_device())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up the Optimizer\n",
    "\n",
    "For fine-tuning, we use the AdamW optimizer, a widely used choice for training Transformer-based models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "# Initialize the optimizer with a learning rate of 1e-5\n",
    "optimizer = AdamW(model.parameters(), lr=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure the Chat Template\n",
    "\n",
    "Align-Anything uses a Chat Template to format the model's input. Here we use the *AA_TI2T* Chat Template, which is designed for the Align-Anything-TI2T-Instruction-100K dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from align_anything.configs.template import ChatTemplate\n",
    "\n",
    "train_template = ChatTemplate(\n",
    "    formatter=processor,\n",
    "    template=\"AA_TI2T\",\n",
    ")"
   ]
  },
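  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Chat Template is a dataset-specific formatter that maps raw samples to the model's input format. The *AA_TI2T* template reads the `prompt`, `response`, and `image` fields of each raw sample and returns a two-turn conversation together with the image. Its definition is as follows:\n",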
    "\n",
    "```python\n",
    "@register_template('AA_TI2T')\n",
    "class AA_TI2T(BaseFormatter):\n",
    "    system_prompt: str = ''\n",
    "    \n",
    "    def format_supervised_sample(\n",
    "        self, raw_sample: dict[str, Any]\n",
    "    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:\n",
    "        prompt = raw_sample['prompt']\n",
    "        answer = raw_sample['response']\n",
    "        image = raw_sample['image'].convert('RGBA')\n",
    "\n",
    "        return [\n",
    "            {\n",
    "                'role': 'user',\n",
    "                'content': [\n",
    "                    {'type': 'image'},\n",
    "                    {'type': 'text', 'text': prompt},\n",
    "                ],\n",
    "            },\n",
    "            {'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]},\n",
    "        ], {'image': image}\n",
    "```"
   ]
  },
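  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the structure of the template's return value concrete, the following minimal sketch reproduces the same message layout in plain Python. The function name `build_ti2t_messages` and the sample values are hypothetical and used only for illustration; the image is replaced by a placeholder string so that the snippet runs without PIL or Align-Anything.\n",
    "\n",
    "```python\n",
    "from typing import Any\n",
    "\n",
    "\n",
    "def build_ti2t_messages(\n",
    "    raw_sample: dict[str, Any],\n",
    ") -> tuple[list[dict[str, Any]], dict[str, Any]]:\n",
    "    # Hypothetical stand-in for AA_TI2T.format_supervised_sample:\n",
    "    # same message layout, but the image is passed through unchanged.\n",
    "    messages = [\n",
    "        {\n",
    "            'role': 'user',\n",
    "            'content': [\n",
    "                {'type': 'image'},  # placeholder slot for the image\n",
    "                {'type': 'text', 'text': raw_sample['prompt']},\n",
    "            ],\n",
    "        },\n",
    "        {\n",
    "            'role': 'assistant',\n",
    "            'content': [{'type': 'text', 'text': raw_sample['response']}],\n",
    "        },\n",
    "    ]\n",
    "    return messages, {'image': raw_sample['image']}\n",
    "\n",
    "\n",
    "# Made-up sample; a real sample carries a PIL image, not a string.\n",
    "sample = {\n",
    "    'prompt': 'What is in the picture?',\n",
    "    'response': 'A cat on a sofa.',\n",
    "    'image': '<image placeholder>',\n",
    "}\n",
    "messages, multi_modal_info = build_ti2t_messages(sample)\n",
    "print(messages[0]['content'][1]['text'])\n",
    "print(messages[1]['role'])\n",
    "```\n",
    "\n",
    "The first element of the returned tuple is the conversation (one user turn with an image slot and the prompt, one assistant turn with the response); the second carries the image itself."
   ]
  },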
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the Dataset\n",
    "\n",
    "We use the `SupervisedDataset` class to load the text-image-to-text dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from align_anything.datasets.text_image_to_text import SupervisedDataset\n",
    "\n",
    "# Initialize the training dataset\n",
    "train_dataset = SupervisedDataset(\n",
    "    path=\"/path/to/Align-Anything-TI2T-Instruction-100K\",  # Replace with your dataset path\n",
    "    template=train_template,\n",
    "    tokenizer=tokenizer,\n",
    "    processor=processor,\n",
    "    split=\"train\",\n",
    "    size=1000,  # Limit to 1000 samples for this tutorial\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up the DataLoader\n",
    "\n",
    "The DataLoader handles batching, shuffling, and loading the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.sampler import RandomSampler\n",
    "\n",
    "# Create a DataLoader for our training dataset\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    collate_fn=train_dataset.get_collator(),  # Custom collate function for our dataset\n",
    "    sampler=RandomSampler(train_dataset),     # Randomly sample data\n",
    "    batch_size=1,                             # Process one sample at a time\n",
    ")"
   ]
  },
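  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check on how many optimizer steps these settings imply, the sketch below computes the number of batches per epoch and in total. The helper name `count_steps` is hypothetical. The values 1000, 1, and 3 come from `size=1000`, `batch_size=1`, and the three epochs used in this tutorial; the formula assumes that all 1000 samples are kept by the dataset and that the DataLoader uses its default `drop_last=False`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def count_steps(num_samples: int, batch_size: int, num_epochs: int) -> tuple[int, int]:\n",
    "    # With drop_last=False (the DataLoader default), the last partial\n",
    "    # batch is kept, so the number of batches per epoch is rounded up.\n",
    "    steps_per_epoch = math.ceil(num_samples / batch_size)\n",
    "    return steps_per_epoch, steps_per_epoch * num_epochs\n",
    "\n",
    "\n",
    "steps_per_epoch, total_steps = count_steps(num_samples=1000, batch_size=1, num_epochs=3)\n",
    "print(steps_per_epoch, total_steps)  # 1000 3000\n",
    "```\n",
    "\n",
    "A larger batch size reduces the step count proportionally but increases the memory needed per step."
   ]
  },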
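  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop\n",
    "\n",
    "We now fine-tune the model for three epochs. After each epoch, the model, tokenizer, and processor are saved to `./output`, overwriting the previous checkpoint."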
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from collections import deque\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "progress_bar = tqdm(range(3*len(train_dataloader)), desc=\"Training for 1/3 epochs...\")\n",
    "losses = deque(maxlen=100)\n",
    "os.makedirs('./output', exist_ok=True)\n",
    "\n",
    "for epoch in range(3):\n",
    "    progress_bar.set_description(f\"Training for {epoch+1}/3 epochs...\")\n",
    "    for batch in train_dataloader:\n",
    "        batch.pop('meta_info')\n",
    "        model.train()\n",
    "        loss = model(**batch)['loss']\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.item())\n",
    "        progress_bar.update(1)\n",
    "        progress_bar.set_postfix(loss=np.mean(losses))\n",
    "\n",
    "    # Save the model after each epoch\n",
    "    model.save_pretrained('./output')\n",
    "    tokenizer.save_pretrained('./output')\n",
    "    processor.save_pretrained('./output')\n"
   ]
  },
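  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above reports `np.mean(losses)` over a `deque(maxlen=100)`, so the value shown in the progress bar is a moving average of at most the 100 most recent losses rather than an average over the whole run. The sketch below illustrates that windowing behavior with a small window and made-up loss values; `statistics.fmean` stands in for `np.mean` so that the snippet needs only the standard library.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "from statistics import fmean\n",
    "\n",
    "window = deque(maxlen=3)  # small window for illustration; the tutorial uses 100\n",
    "for step_loss in [4.0, 3.0, 2.0, 1.0]:\n",
    "    window.append(step_loss)  # once the window is full, the oldest value is dropped\n",
    "    print(list(window), fmean(window))\n",
    "```\n",
    "\n",
    "After the fourth value is appended, the first one has left the window, so the reported mean reflects only the three most recent losses."
   ]
  },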
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The complete code is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "from collections import deque\n",
    "import numpy as np\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.sampler import RandomSampler\n",
    "from torch.optim import AdamW\n",
    "\n",
    "from align_anything.models.pretrained_model import load_pretrained_models\n",
    "from align_anything.datasets.text_image_to_text import SupervisedDataset\n",
    "from align_anything.configs.template import ChatTemplate\n",
    "from align_anything.utils.multi_process import get_current_device\n",
    "\n",
    "\n",
    "model, tokenizer, processor = load_pretrained_models(\n",
    "    \"/path/to/llava-1.5-7b-hf\",\n",
    "    model_max_length=4096,\n",
    "    padding_side='right',\n",
    "    trust_remote_code=True,\n",
    "    modality=['image'],\n",
    ")\n",
    "\n",
    "model = model.to(get_current_device())\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-5)\n",
    "\n",
    "train_template = ChatTemplate(\n",
    "    formatter=processor,\n",
    "    template=\"AA_TI2T\",\n",
    ")\n",
    "\n",
    "train_dataset = SupervisedDataset(\n",
    "    path=\"/path/to/Align-Anything-TI2T-Instruction-100K\",  # Replace with your dataset path\n",
    "    template=train_template,\n",
    "    tokenizer=tokenizer,\n",
    "    processor=processor,\n",
    "    split=\"train\",\n",
    "    size=1000,  # Limit to 1000 samples for this tutorial\n",
    ")\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    collate_fn=train_dataset.get_collator(),\n",
    "    sampler=RandomSampler(train_dataset),\n",
    "    batch_size=1,\n",
    ")\n",
    "\n",
    "progress_bar = tqdm(range(3*len(train_dataloader)), desc=\"Training for 1/3 epochs...\")\n",
    "losses = deque(maxlen=100)\n",
    "os.makedirs('./output', exist_ok=True)\n",
    "\n",
    "for epoch in range(3):\n",
    "    progress_bar.set_description(f\"Training for {epoch+1}/3 epochs...\")\n",
    "    for batch in train_dataloader:\n",
    "        batch.pop('meta_info')\n",
    "        model.train()\n",
    "        loss = model(**batch)['loss']\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.item())\n",
    "        progress_bar.update(1)\n",
    "        progress_bar.set_postfix(loss=np.mean(losses))\n",
    "\n",
    "    # Save the model after each epoch\n",
    "    model.save_pretrained('./output')\n",
    "    tokenizer.save_pretrained('./output')\n",
    "    processor.save_pretrained('./output')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jy-align",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
